import sys
import pickle
import pandas as pd
sys.path.append(sys.argv[4])
from signal_udfs import plot_all_signals
from matplotlib import pyplot as plt
import os
import random

# Command-line arguments: the input .csv of raw data, the path whose stem names
# the output traces .csv, and the path whose stem names the per-event .png
# plots. (sys.argv[4], the helper-module directory, is consumed by the
# sys.path.append near the imports.)
input_handle, output_handle, dff0_traceout_handle = sys.argv[1:4]


# Containers for the imported raw data. Each CSV row carries two side-by-side
# tables: a behavior bout in columns 0-2 (name, start, end) and one DF/F0
# sample in columns 3-4 (time, value). Empty cells are skipped, so the two
# tables may have different numbers of rows.
dff0_dict = {name: [] for name in ['Time', 'DF/F0']}
behavior_dict = {name: [] for name in ['Name', 'Start', 'End']}

# Import raw data from .csv file. The file is opened in a `with` block so it
# is closed as soon as parsing finishes.
with open(input_handle) as raw_file:
    next(raw_file, None)  # skip the header row (line 0)
    for line_num, line in enumerate(raw_file, start=1):
        for col, field in enumerate(line.strip().split(',')):
            if field == '':
                continue
            if col == 0:
                # Suffix the CSV line number so repeated behavior names stay
                # unique (later steps de-duplicate events by name).
                behavior_dict['Name'].append(field + ' ' + str(line_num))
            elif col == 1:
                behavior_dict['Start'].append(float(field))
            elif col == 2:
                behavior_dict['End'].append(float(field))
            elif col == 3:
                dff0_dict['Time'].append(float(field))
            elif col == 4:
                dff0_dict['DF/F0'].append(float(field))

# Get non-bout timestamps: for each bout, the DF/F0 timestamps lying strictly
# after the end of the previous bout (or, for the first bout, anywhere before
# it) and before the start of this bout form one candidate "Non-Bout" segment.
# Accepted segments are appended to behavior_dict as extra events.

# Snapshot the bout boundaries first: the loop below appends non-bout
# segments to behavior_dict, and those must not be treated as bouts.
start_stop = [
    [bout_start, behavior_dict['End'][event_num]]
    for event_num, bout_start in enumerate(behavior_dict['Start'])
]

# A segment is skipped if its first sample is at or before this many seconds,
# or its last sample is within this many seconds of the final timestamp.
edge_margin = 3

for ts_num, (bout_start, _bout_end) in enumerate(start_stop):
    if ts_num == 0:
        # Nothing precedes the first bout, so its gap is unbounded below.
        temp_ts = [t for t in dff0_dict['Time'] if t < bout_start]
    else:
        prev_end = start_stop[ts_num - 1][1]
        temp_ts = [t for t in dff0_dict['Time'] if prev_end < t < bout_start]
    # The gap after the last bout is intentionally not collected: it would run
    # up to the final timestamp (assuming Time is ascending) and therefore
    # could never pass the end-of-recording margin check below.
    if (temp_ts
            and temp_ts[0] > edge_margin
            and temp_ts[-1] < dff0_dict['Time'][-1] - edge_margin):
        behavior_dict['Start'].append(temp_ts[0])
        behavior_dict['End'].append(temp_ts[-1])
        behavior_dict['Name'].append('Non-Bout ' + str(ts_num + 1))

# Make random sampling timestamps: split the recording into thirds and, in
# each third, draw up to `n_random_samples` window onsets without replacement
# from a grid of whole seconds spaced `sample_step` apart. Each window is
# stored as [onset, onset + sample_window].
recording_end = dff0_dict['Time'][-1]
interval = int(recording_end / 3)

n_random_samples = 10
sample_step = 30
sample_window = 10
end_margin = 10  # 'End' windows must finish this many seconds before the last sample

# (low, high) onset bounds per third; `high` is exclusive, as in range().
# 'Beginning' starts at 4 s, presumably to leave room for the pre-onset
# baseline window extracted later -- TODO confirm.
sample_bounds = {
    'Beginning': (4, interval),
    'Middle': (interval, interval * 2),
    'End': (interval * 2, interval * 3),
}

random_sampling = {}
for name, (low, high) in sample_bounds.items():
    candidates = list(range(low, high, sample_step))
    if name == 'End':
        # Filter BEFORE sampling so invalid onsets never use up one of the
        # n_random_samples draws while valid onsets remain available.
        candidates = [
            onset for onset in candidates
            if onset + sample_window < recording_end - end_margin
        ]
    picks = random.sample(candidates, min(n_random_samples, len(candidates)))
    random_sampling[name] = [[onset, onset + sample_window] for onset in picks]
# Add random sampling to behavior list: every random window becomes an extra
# event named after its third of the recording plus a 1-based index, so it is
# extracted and plotted exactly like a real bout.
for section, windows in random_sampling.items():
    for idx, window in enumerate(windows, start=1):
        behavior_dict['Name'].append(f'Random Sample ({section}) {idx}')
        behavior_dict['Start'].append(window[0])
        behavior_dict['End'].append(window[-1])

# Create dict for extracted traces. Around each event onset (its 'Start'):
#   Baseline = DF/F0 samples with onset - base_epoch <= t < onset
#   Event    = DF/F0 samples with onset <= t < onset + base_epoch
#   Full     = Baseline and Event samples together, in the order of Time
# '... Time' keys hold absolute timestamps, '... Time Norm' keys hold
# timestamps relative to the onset, and '... Trace' keys hold DF/F0 values.
names = [
    'Name', 'Baseline Time', 'Baseline Trace', 'Event Time', 'Event Trace',
    'Full Time', 'Full Trace', 'Baseline Time Norm', 'Event Time Norm',
    'Full Time Norm',
]
trace_dict = {key: [] for key in names}
base_epoch = 2  # float or int: seconds covered by both the baseline and the event window

seen_names = set()
for event_idx, event_name in enumerate(behavior_dict['Name']):
    # Only the first event carrying a given name is extracted.
    if event_name in seen_names:
        continue
    seen_names.add(event_name)

    onset = behavior_dict['Start'][event_idx]
    window_lo = onset - base_epoch
    window_hi = onset + base_epoch

    # Each part is (absolute times, onset-relative times, DF/F0 values).
    baseline = ([], [], [])
    event = ([], [], [])
    full = ([], [], [])
    for sample_idx, t in enumerate(dff0_dict['Time']):
        if window_lo <= t < onset:
            part = baseline
        elif onset <= t < window_hi:
            part = event
        else:
            continue
        value = dff0_dict['DF/F0'][sample_idx]
        for abs_times, rel_times, values in (part, full):
            abs_times.append(t)
            rel_times.append(t - onset)
            values.append(value)

    trace_dict['Name'].append(event_name)
    trace_dict['Baseline Time'].append(baseline[0])
    trace_dict['Baseline Time Norm'].append(baseline[1])
    trace_dict['Baseline Trace'].append(baseline[2])
    trace_dict['Event Time'].append(event[0])
    trace_dict['Event Time Norm'].append(event[1])
    trace_dict['Event Trace'].append(event[2])
    trace_dict['Full Time'].append(full[0])
    trace_dict['Full Time Norm'].append(full[1])
    trace_dict['Full Trace'].append(full[2])

# Plot traces: one PNG per event showing its full peri-event DF/F0 trace
# against time relative to the event onset.
for name, rel_time, trace in zip(trace_dict['Name'],
                                 trace_dict['Full Time Norm'],
                                 trace_dict['Full Trace']):
    fig, ax = plt.subplots()
    ax.plot(rel_time, trace)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('DF/F0 (%)')
    ax.set_title(name)
    # The '_' before '4s' keeps an event's trailing number (e.g. the CSV line
    # suffix in 'Grooming 12') from running into the window length; it also
    # matches the '_4s_peri-event' suffix used for the traces CSV.
    # '4s' presumably reflects baseline + event = 2 * base_epoch -- confirm.
    fig.savefig(dff0_traceout_handle[:-4] + '_' + name
                + '_4s_peri-event_dff0_trace.png')
    plt.close(fig)

# Save event trace outputs. Every column is aligned on its event onset:
# the event's baseline samples come first, then its event-window samples.
# Columns are padded with NaN at the front (short baseline, e.g. an event
# close to the recording start) or at the back (short event window, e.g. an
# event close to the recording end). This puts the onset on the same row
# for all events.
base_lens = [len(trace) for trace in trace_dict['Baseline Trace']]
event_lens = [len(trace) for trace in trace_dict['Event Trace']]
max_base = max(base_lens, default=0)
max_event = max(event_lens, default=0)

# Time column: onset-relative times from the first event with the longest
# baseline, followed by those from the first event with the longest event
# window. When all events have equal lengths, both come from the first event.
ref_base_time = next(
    (times for times, n in zip(trace_dict['Baseline Time Norm'], base_lens)
     if n == max_base),
    [],
)
ref_event_time = next(
    (times for times, n in zip(trace_dict['Event Time Norm'], event_lens)
     if n == max_event),
    [],
)

nan = float('nan')
columns = {'Time (s)': list(ref_base_time) + list(ref_event_time)}
for name, base, event in zip(trace_dict['Name'],
                             trace_dict['Baseline Trace'],
                             trace_dict['Event Trace']):
    columns[name] = ([nan] * (max_base - len(base)) + list(base)
                     + list(event) + [nan] * (max_event - len(event)))

df_trace_out = pd.DataFrame(columns)
df_trace_out.to_csv(output_handle[:-4] + '_4s_peri-event_dff0_traces.csv',
                    index=False)

# Pickle the extracted data to fixed file names in the current working
# directory. Each file is written inside a `with` block, so it is flushed
# and closed immediately.
# NOTE: unpickling can execute arbitrary code; only load these files from
# trusted locations.
pickle_targets = (
    (trace_dict, 'trace_dict_4s_peri.pkl'),
    (dff0_dict, 'dff0_dict_4s_peri.pkl'),
    (behavior_dict, 'behavior_dict.pkl'),
    (base_epoch, 'base_epoch_4s_peri.pkl'),
)
for obj, path in pickle_targets:
    with open(path, 'wb') as pkl_file:
        pickle.dump(obj, pkl_file)
